{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_"
      },
      "outputs": [],
      "source": [
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ag4NOGwC3lR"
      },
      "source": [
        "**TL;DR:** Applying a JAX transform inside of a `hk.transform` is likely to transform a side-effecting function, which results in an `UnexpectedTracerError`. This page describes two ways to get around this."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kXvm0ueZC3lU"
      },
      "source": [
        "# Limitations of Nesting JAX Functions and Haiku Modules\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XKDdO7n_C3lU"
      },
      "source": [
        "Once a Haiku network has been transformed into a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformation, such as `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.\n",
        "\n",
        "If you want to use JAX transformations **inside** of a `hk.transform`, however, you need to be more careful. It's possible, but most functions inside of the `hk.transform` boundary are still side-effecting and cannot safely be transformed by JAX.\n",
        "This is a common cause of `UnexpectedTracerError`s in code using Haiku. These errors are the result of applying a JAX transform to a side-effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).\n",
        "\n",
        "An example with `jax.eval_shape`:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CMnZLQgPC3lV",
        "outputId": "2305be6d-2976-4f0b-a4f7-eb999a46c16d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "UnexpectedTracerError: applied JAX transform to side effecting function\n"
          ]
        }
      ],
      "source": [
        "def net(x): # inside of a hk.transform, this is still side-effecting\n",
        "  w = hk.get_parameter(\"w\", (2, 2), init=jnp.ones)\n",
        "  return w @ x\n",
        "\n",
        "def eval_shape_net(x):\n",
        "  output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function\n",
        "  return net(x)                         # UnexpectedTracerError!\n",
        "\n",
        "init, _ = hk.transform(eval_shape_net)\n",
        "try:\n",
        "  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))\n",
        "except jax.errors.UnexpectedTracerError:\n",
        "  print(\"UnexpectedTracerError: applied JAX transform to side effecting function\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pZrXKtC0C3lX"
      },
      "source": [
        "The examples on this page use `jax.eval_shape`, but any higher-order JAX function (e.g. `jax.vmap`, `jax.lax.scan`, `jax.lax.while_loop`, ...) would behave the same way.\n",
        "\n",
        "The error points to `hk.get_parameter`. This is the operation that makes `net` a side-effecting function. The side effect in this case is the creation of a parameter, which gets stored in the Haiku state. Similarly, you would get an error using `hk.next_rng_key`, because it advances the Haiku RNG state and stores a new `PRNGKey` in the Haiku state. In general, applying a JAX transform to a Haiku module or function that has not itself been transformed into a pure function will result in an `UnexpectedTracerError`.\n",
        "\n",
        "You could rewrite the code above to create the parameter outside of the `eval_shape` transformation, making `net` a pure function by threading the parameter through explicitly as an argument:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BG68aS9lC3lY",
        "outputId": "3fe34887-cb81-499c-8dda-e8f8f818e4a3"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(ShapeDtypeStruct(shape=(3, 3), dtype=float32),\n",
              " DeviceArray([[2., 2., 2.],\n",
              "              [2., 2., 2.],\n",
              "              [2., 2., 2.]], dtype=float32))"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "def net(w, x): # no side effects!\n",
        "  return w @ x\n",
        "\n",
        "def eval_shape_net(x):\n",
        "  w = hk.get_parameter(\"w\", (3, 2), init=jnp.ones)\n",
        "  output_shape = jax.eval_shape(net, w, x) # net is now side-effect free\n",
        "  return output_shape, net(w, x)\n",
        "\n",
        "key = jax.random.PRNGKey(777)\n",
        "x = jnp.ones((2, 3))\n",
        "init, apply = hk.transform(eval_shape_net)\n",
        "params = init(key, x)\n",
        "apply(params, key, x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FfVWrRg9C3lZ"
      },
      "source": [
        "However, that's not always possible. Consider the following code, which calls a Haiku module (`hk.nets.MLP`) that we don't own. This module will internally call `hk.get_parameter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iBJ-JpDYC3la",
        "outputId": "6bbb43bf-bfc3-4b7e-fdd9-e471c627be32"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "UnexpectedTracerError: applied JAX transform to side effecting function\n"
          ]
        }
      ],
      "source": [
        "def eval_shape_net(x):\n",
        "  net = hk.nets.MLP([300, 100])\n",
        "  output_shape = jax.eval_shape(net, x)\n",
        "  return output_shape, net(x)\n",
        "\n",
        "init, _ = hk.transform(eval_shape_net)\n",
        "try:\n",
        "  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))\n",
        "except jax.errors.UnexpectedTracerError:\n",
        "  print(\"UnexpectedTracerError: applied JAX transform to side effecting function\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e04UC5ZbC3lb"
      },
      "source": [
        "## Using hk.lift\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O7E0ldV7C3lc"
      },
      "source": [
        "We want a way to get access to our implicit Haiku state and get a functionally pure version of `hk.nets.MLP`. The usual way to achieve this is `hk.transform`, so all we need is a way to nest an inner `hk.transform` inside an outer `hk.transform`! We'll create another pair of `init` and `apply` functions through `hk.transform`, and these can then be safely combined with any higher-order JAX function.\n",
        "\n",
        "However, we need a way to register the state of this nested `hk.transform` in the outer scope. We can use `hk.lift` for this. Wrapping our inner `init` function with `hk.lift` will register our inner `params` in the outer parameter scope."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9w7nuPF1C3lc",
        "outputId": "042d3e7c-cac7-4fc1-cfd0-495b69a193c2"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "FlatMap({\n",
              "  'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),\n",
              "  'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),\n",
              "  'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),\n",
              "  'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),\n",
              "})"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "def eval_shape_net(x):\n",
        "  net = hk.nets.MLP([300, 100])    # still side-effecting\n",
        "  init, apply = hk.transform(net)  # nested transform\n",
        "  params = hk.lift(init, name=\"inner\")(hk.next_rng_key(), x) # register parameters in outer module scope with name \"inner\"\n",
        "  output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x) # apply is a functionally pure function and can be transformed!\n",
        "  out = net(x)\n",
        "  return out, output_shape\n",
        "\n",
        "\n",
        "init, apply = hk.transform(eval_shape_net)\n",
        "params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n",
        "apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))\n",
        "jax.tree.map(lambda x: x.shape, params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "90dfm3H6C3ld"
      },
      "source": [
        "## Using Haiku versions of JAX transforms\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "damGG_NkC3ld"
      },
      "source": [
        "Haiku also provides wrapped versions of some JAX functions for convenience, for example `hk.grad`, `hk.vmap`, and `hk.eval_shape`. See https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals for a full list of available functions.\n",
        "\n",
        "These wrappers apply the JAX function to a functionally pure version of the Haiku function by doing the explicit state threading for you. Unlike `hk.lift`, they don't introduce an extra name-scoping level."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S-b4sFboC3ld"
      },
      "outputs": [],
      "source": [
        "def eval_shape_net(x):\n",
        "  net = hk.nets.MLP([300, 100])         # still side-effecting\n",
        "  output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads through the Haiku state for you\n",
        "  out = net(x)\n",
        "  return out, output_shape\n",
        "\n",
        "\n",
        "init, apply = hk.transform(eval_shape_net)\n",
        "params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n",
        "out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))"
      ]
    },
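    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The other wrappers follow the same pattern. As a minimal sketch (not taken from the original notebook), the cell below uses `hk.vmap` to map a small per-example function over a batch inside of a `hk.transform`. The names `per_example_net` and `batched_net` and the layer size are made up for illustration. Passing `split_rng=False` is an assumption made here so that all batch elements share the same RNG keys and therefore the same parameters; `hk.vmap` requires you to state this choice explicitly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def per_example_net(x):   # hypothetical helper, written for one example of shape [3]\n",
        "  return hk.Linear(4)(x)  # still side-effecting: creates parameters\n",
        "\n",
        "def batched_net(xs):\n",
        "  # hk.vmap threads the Haiku state through jax.vmap for us.\n",
        "  # split_rng=False (an assumption for this sketch) shares RNG keys, and\n",
        "  # therefore the parameter initialization, across the batch.\n",
        "  return hk.vmap(per_example_net, split_rng=False)(xs)\n",
        "\n",
        "init, apply = hk.transform(batched_net)\n",
        "xs = jnp.ones((8, 3))                      # batch of 8 examples\n",
        "params = init(jax.random.PRNGKey(0), xs)\n",
        "out = apply(params, jax.random.PRNGKey(1), xs)\n",
        "out.shape"
      ]
    },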
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "syzC6xUHCR13"
      },
      "source": [
        "## Summary\n",
        "\n",
        "To summarize, some good and bad examples of combining JAX transforms and Haiku modules:\n",
        "\n",
        "| What?                                         | Works?                          | Example                                                           |\n",
        "|-----------------------------------------------|---------------------------------|-------------------------------------------------------------------|\n",
        "| vmapping outside a hk.transform               | ✔ yes!                          | `jax.vmap(hk.transform(hk.nets.ResNet50))`                        |\n",
        "| vmapping inside a hk.transform                | ✖ unexpected tracer error       | `hk.transform(jax.vmap(hk.nets.ResNet50))`                        |\n",
        "| vmapping a nested hk.transform (without lift) | ✖ inner state is not registered | `hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50)))`          |\n",
        "| vmapping a nested hk.transform (with lift)    | ✔ yes!                          | `hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50))))` |\n",
        "| using hk.vmap                                 | ✔ yes!                          | `hk.transform(hk.vmap(hk.nets.ResNet50))`                         |\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "transforms.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
